#!/usr/bin/python3

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import json
import logging
import os
import random

import numpy as np
import torch

from torch.utils.data import DataLoader

from model import KGEModel

from dataloader import TrainDataset
from dataloader import BidirectionalOneShotIterator
import pdb

def parse_args(args=None):
    """Declare every command-line option of this script and parse ``args``.

    Args:
        args: Optional list of argument strings. ``None`` lets argparse fall
            back to the process command line.

    Returns:
        argparse.Namespace with one attribute per option declared below.
    """
    cli = argparse.ArgumentParser(
        description='Training and Testing Knowledge Graph Embedding Models',
        usage='train.py [<args>] [-h | --help]',
    )
    opt = cli.add_argument

    # Device selection.
    opt('--cuda', action='store_true', help='use GPU')

    # Which phases to run.
    opt('--do_train', action='store_true')
    opt('--do_valid', action='store_true')
    opt('--do_test', action='store_true')
    opt('--evaluate_train', action='store_true', help='Evaluate on training data')

    # Countries benchmark handling.
    opt('--countries', action='store_true', help='Use Countries S1/S2/S3 datasets')
    opt('--regions', type=int, nargs='+', default=None,
        help='Region Id for Countries S1/S2/S3 datasets, DO NOT MANUALLY SET')

    # Dataset location and model family.
    opt('--data_path', type=str, default=None)
    opt('--model', default='TransE', type=str)
    opt('-de', '--double_entity_embedding', action='store_true')
    opt('-dr', '--double_relation_embedding', action='store_true')

    # Model and loss hyper-parameters.
    opt('-n', '--negative_sample_size', default=128, type=int)
    opt('-d', '--hidden_dim', default=500, type=int)
    opt('-g', '--gamma', default=12.0, type=float)
    opt('-adv', '--negative_adversarial_sampling', action='store_true')
    opt('-a', '--adversarial_temperature', default=1.0, type=float)
    opt('-b', '--batch_size', default=1024, type=int)
    opt('-r', '--regularization', default=0.0, type=float)
    opt('--test_batch_size', default=4, type=int, help='valid/test batch size')
    opt('--uni_weight', action='store_true',
        help='Otherwise use subsampling weighting like in word2vec')

    # Optimisation schedule and checkpoint locations.
    opt('-lr', '--learning_rate', default=0.0001, type=float)
    opt('-cpu', '--cpu_num', default=10, type=int)
    opt('-init', '--init_checkpoint', default=None, type=str)
    opt('-save', '--save_path', default=None, type=str)
    opt('--max_steps', default=100000, type=int)
    opt('--warm_up_steps', default=None, type=int)

    # Periodic actions.
    opt('--save_checkpoint_steps', default=10000, type=int)
    opt('--valid_steps', default=10000, type=int)
    opt('--log_steps', default=100, type=int, help='train log every xx steps')
    opt('--test_log_steps', default=1000, type=int, help='valid/test log every xx steps')

    # Filled in by main() from the dataset dictionaries.
    opt('--nentity', type=int, default=0, help='DO NOT MANUALLY SET')
    opt('--nrelation', type=int, default=0, help='DO NOT MANUALLY SET')

    # Curvature options.
    opt('--alpha', default=0.001, type=float, help='ricci_eta weight')
    opt('--curv_method', default='forman', choices=['none', 'forman', 'olliver'],
        type=str, help='curv methods')

    return cli.parse_args(args)
    
def build_edge_index_map(triples):
    """
    Build a dictionary that maps the ordered pair (head, tail) of every
    triple to the position of that triple in ``triples``.

    The key is directed: ``(h, t)`` is stored, ``(t, h)`` is not. When the
    same pair occurs in several triples, the last occurrence wins.

    Args:
        triples: sequence of (h, r, t) triples. May be a tensor/array that
            offers ``tolist()``, or a plain sequence whose elements are either
            Python numbers or scalar objects exposing ``item()``.

    Returns:
        dict: key = (h, t) -> index of the triple
    """
    # A single bulk conversion replaces two scalar ``.item()`` reads per
    # triple when the input is a tensor or array.
    rows = triples.tolist() if hasattr(triples, 'tolist') else triples

    edge_index_map = {}
    for idx, (h, _, t) in enumerate(rows):
        # Rows of a plain sequence may still hold scalar tensors.
        if hasattr(h, 'item'):
            h = h.item()
        if hasattr(t, 'item'):
            t = t.item()
        edge_index_map[(h, t)] = idx
    return edge_index_map
    
def idx_of(edge_index_map, i, j):
    """
    Return the index stored for the ordered pair (i, j).

    Only the key ``(i, j)`` is queried; the reversed pair ``(j, i)`` is not
    tried.

    Args:
        edge_index_map (dict): mapping from (node, node) pairs to indices
        i, j: node ids forming the key

    Returns:
        The stored index, or -1 when ``(i, j)`` is not a key of the map.
    """
    return edge_index_map.get((i, j), -1)
    
def lowbound_olliver_curvature(triples, dis, neighbors, edge_index_map, num_nodes):
    """
    Compute a per-edge curvature value directly from edge distances.

    For each distance d the result is ``1 - 1 / exp(-d)``.

    Args:
        triples: unused; kept so the signature matches the other curvature
            routine in this file.
        dis: tensor of distances whose dimension 1 is squeezed away
            (presumably shape (E, 1) - confirm against the model output).
        neighbors: unused.
        edge_index_map: unused.
        num_nodes: unused.

    Returns:
        torch.Tensor with one curvature value per entry of ``dis``. It lives
        on the current CUDA device when CUDA is available, otherwise on the
        device of ``dis``.
    """
    dis = dis.squeeze(1)
    # Use the GPU when there is one, but do not crash on CPU-only hosts.
    if torch.cuda.is_available():
        dis = dis.cuda()

    dis_new = torch.exp(-dis)
    return 1.0 - 1.0 / dis_new
    
def compute_node_forman_curvature_from_distances(triples, dis, neighbors, edge_index_map, num_nodes, aggregate='sum'):
    """
    Compute a Forman-Ricci-style curvature value for every edge (triple).

    Every edge e gets the weight ``w_e = exp(-dis_e)``. For every node i:

    * ``S_i`` is the sum of ``w`` over the edge indices found under the key
      ``(i, j)`` in ``edge_index_map`` for each ``j`` in ``neighbors[i]``.
      Entries with ``j == i`` and pairs missing from the map are skipped;
      a neighbour listed several times is counted several times.
    * ``c_i`` is the number of times i appears as head or tail in
      ``triples`` (0 is replaced by 1).

    The value returned for edge (h, t) is::

        w_ht * (2 - S_h / c_h / w_ht - S_t / w_ht / c_t)

    Args:
        triples: 2-D tensor/array indexed as ``triples[:, 0]`` (heads) and
            ``triples[:, 2]`` (tails).
        dis: tensor of edge distances whose dimension 1 is squeezed away
            (presumably shape (E, 1) - confirm against the model output).
        neighbors: mapping node id -> iterable of neighbouring node ids.
        edge_index_map (dict): (node, node) -> edge index, as produced by
            build_edge_index_map().
        num_nodes (int): number of nodes; node ids index tensors of this size.
        aggregate: currently ignored; the node sums are always divided by the
            node counts as shown above.

    Returns:
        torch.Tensor with one curvature value per edge. It lives on the
        current CUDA device when CUDA is available, otherwise on the device
        of ``dis``.
    """
    dis = dis.squeeze(1)
    # Use the GPU when there is one, but do not crash on CPU-only hosts.
    if torch.cuda.is_available():
        dis = dis.cuda()
    device = dis.device

    w_ht = torch.exp(-dis)  # one weight per edge
    # One host copy of the weights, so the per-node loop below does not
    # trigger a device read for every neighbour.
    w_cpu = w_ht.detach().cpu()

    ww = torch.zeros(num_nodes, dtype=torch.float32)
    for i in range(num_nodes):
        edge_ids = []
        for j in neighbors[i]:
            if j == i:
                continue
            # Directed lookup, same as idx_of(): -1 means "no such edge".
            idx = edge_index_map.get((i, j), -1)
            if idx != -1:
                edge_ids.append(idx)
        if edge_ids:
            ww[i] = w_cpu[edge_ids].sum()

    h_idx = triples[:, 0].tolist()
    t_idx = triples[:, 2].tolist()
    sum_wh = ww[h_idx].to(device)
    sum_wt = ww[t_idx].to(device)

    # Number of head/tail occurrences per node, computed in one call.
    count = torch.bincount(
        torch.tensor(h_idx + t_idx, dtype=torch.long), minlength=num_nodes
    ).to(torch.float32)
    count = count.clamp_(min=1).to(device)  # avoid division by zero

    kappa_ht = w_ht * (2 - sum_wh / count[h_idx] / w_ht - sum_wt / w_ht / count[t_idx])

    return kappa_ht
    
def build_curvature_lookup(triples, edge_curvatures):
    """
    Pair every triple with its curvature value in a dictionary.

    Both inputs are moved to the CPU and converted to Python lists; rows of
    ``triples`` become tuple keys. Pairing is positional, and a triple that
    occurs more than once keeps the value of its last occurrence.

    Args:
        triples: tensor whose rows are the triples
        edge_curvatures: tensor with one value per row of ``triples``

    Returns:
        dict: {tuple(row): curvature_value}
    """
    keys = map(tuple, triples.cpu().tolist())
    values = edge_curvatures.cpu().tolist()
    return dict(zip(keys, values))
def override_config(args):
    '''
    Replace model and data settings in ``args`` with the values stored in
    ``config.json`` inside ``args.init_checkpoint``.

    ``data_path`` is taken from the file only when it is still None on
    ``args``; every other listed setting is always overwritten.
    '''
    config_file = os.path.join(args.init_checkpoint, 'config.json')
    with open(config_file, 'r') as fjson:
        saved = json.load(fjson)

    args.countries = saved['countries']
    if args.data_path is None:
        args.data_path = saved['data_path']

    always_restored = (
        'model',
        'double_entity_embedding',
        'double_relation_embedding',
        'hidden_dim',
        'test_batch_size',
    )
    for key in always_restored:
        setattr(args, key, saved[key])
    
def save_model(model, optimizer, save_variable_list, args):
    '''
    Write a checkpoint into ``args.save_path``:

    - ``config.json``: all attributes of ``args`` as JSON
    - ``checkpoint``: ``save_variable_list`` plus the state dicts of the
      model and the optimizer, written with torch.save
    - ``entity_embedding`` / ``relation_embedding``: the corresponding model
      attributes as NumPy arrays, written with np.save
    '''
    out_dir = args.save_path

    with open(os.path.join(out_dir, 'config.json'), 'w') as fjson:
        json.dump(vars(args), fjson)

    checkpoint = dict(save_variable_list)
    checkpoint['model_state_dict'] = model.state_dict()
    checkpoint['optimizer_state_dict'] = optimizer.state_dict()
    torch.save(checkpoint, os.path.join(out_dir, 'checkpoint'))

    for name in ('entity_embedding', 'relation_embedding'):
        weights = getattr(model, name).detach().cpu().numpy()
        np.save(os.path.join(out_dir, name), weights)

def read_triple(file_path, entity2id, relation2id):
    '''
    Read tab-separated ``head<TAB>relation<TAB>tail`` lines from
    ``file_path`` and map them into ids.

    Blank lines (for example a trailing empty line at the end of the file)
    are skipped. A non-blank line that does not split into exactly three
    fields raises ValueError, and an unknown entity or relation name raises
    KeyError.

    Returns a list of (head_id, relation_id, tail_id) tuples in file order.
    '''
    triples = []
    with open(file_path) as fin:
        for line in fin:
            line = line.strip()
            if not line:
                # Nothing to parse; unpacking would fail on an empty line.
                continue
            h, r, t = line.split('\t')
            triples.append((entity2id[h], relation2id[r], entity2id[t]))
    return triples

def set_logger(args):
    '''
    Send log records both to a file and to the console.

    The file is ``train.log`` when training and ``test.log`` otherwise,
    placed in ``args.save_path`` or, if that is unset, in
    ``args.init_checkpoint``. It is opened in write mode.
    '''
    log_dir = args.save_path or args.init_checkpoint
    log_name = 'train.log' if args.do_train else 'test.log'

    logging.basicConfig(
        filename=os.path.join(log_dir, log_name),
        filemode='w',
        level=logging.INFO,
        format='%(asctime)s %(levelname)-8s %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )

    # Mirror INFO-and-above records on the console.
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    console.setFormatter(logging.Formatter('%(asctime)s %(levelname)-8s %(message)s'))
    logging.getLogger('').addHandler(console)

def log_metrics(mode, step, metrics):
    '''
    Log one INFO line per entry of ``metrics``, formatted as
    "<mode> <metric name> at step <step>: <value>".
    '''
    for name, value in metrics.items():
        logging.info('%s %s at step %d: %f' % (mode, name, step, value))
        
        
def _read_id_dict(path):
    """Read a tab-separated ``<id> <name>`` file into a ``{name: int(id)}`` dict."""
    mapping = {}
    with open(path) as fin:
        for line in fin:
            line = line.strip()
            if not line:
                continue  # tolerate blank lines such as a trailing newline
            idx, name = line.split('\t')
            mapping[name] = int(idx)
    return mapping


def _refresh_curvature(kge_model, train_examples, neighbors, edge_index_map, kappa_edge, args):
    """Recompute edge curvatures, centre them, log their sign statistics and
    hand the resulting lookup table to the model. Returns the centred tensor."""
    if args.curv_method in ('forman', 'olliver'):
        # Only the numeric values are used afterwards, so do not build an
        # autograd graph over the whole training set.
        with torch.no_grad():
            dis = kge_model(train_examples)
            if args.curv_method == 'forman':
                kappa_edge = compute_node_forman_curvature_from_distances(
                    train_examples, dis, neighbors, edge_index_map, args.nentity, aggregate='mean')
            else:
                kappa_edge = lowbound_olliver_curvature(
                    train_examples, dis, neighbors, edge_index_map, args.nentity)
    # With curv_method == 'none' the previous (initially all-zero) values are kept.

    kappa_edge = kappa_edge - kappa_edge.mean()
    variance = kappa_edge.var(unbiased=False)
    logging.info('variance curv = %f' % variance.item())

    total = kappa_edge.numel()
    pos_ratio = (kappa_edge > 0).sum().item() / total
    neg_ratio = (kappa_edge < 0).sum().item() / total
    zero_ratio = (kappa_edge == 0).sum().item() / total
    logging.info(f"Positive ratio: {pos_ratio:.4f}, Negative ratio: {neg_ratio:.4f}, Zero ratio: {zero_ratio:.4f}")

    kge_model.reciev_curv(build_curvature_lookup(train_examples, kappa_edge))
    return kappa_edge


def main(args):
    """
    Run the phases selected on the command line: training (with periodic
    curvature refresh, checkpointing, logging and optional validation),
    followed by evaluation on the valid / test / train splits as requested.

    Args:
        args: namespace produced by parse_args(). ``nentity``, ``nrelation``
            and, for the Countries datasets, ``regions`` are filled in here.

    Raises:
        ValueError: if no mode is selected, if neither ``init_checkpoint``
            nor ``data_path`` is given, or if training is requested without
            ``save_path``.
    """
    if (not args.do_train) and (not args.do_valid) and (not args.do_test):
        raise ValueError('one of train/val/test mode must be choosed.')

    if args.init_checkpoint:
        override_config(args)
    elif args.data_path is None:
        raise ValueError('one of init_checkpoint/data_path must be choosed.')

    if args.do_train and args.save_path is None:
        raise ValueError('Where do you want to save your trained model?')

    if args.save_path and not os.path.exists(args.save_path):
        os.makedirs(args.save_path)

    # Write logs to checkpoint and console
    set_logger(args)

    entity2id = _read_id_dict(os.path.join(args.data_path, 'entities.dict'))
    relation2id = _read_id_dict(os.path.join(args.data_path, 'relations.dict'))

    # Read regions for Countries S* datasets
    if args.countries:
        regions = list()
        with open(os.path.join(args.data_path, 'regions.list')) as fin:
            for line in fin:
                region = line.strip()
                regions.append(entity2id[region])
        args.regions = regions

    nentity = len(entity2id)
    nrelation = len(relation2id)

    args.nentity = nentity
    args.nrelation = nrelation

    logging.info('Model: %s' % args.model)
    logging.info('Data Path: %s' % args.data_path)
    logging.info('#entity: %d' % nentity)
    logging.info('#relation: %d' % nrelation)

    train_triples = read_triple(os.path.join(args.data_path, 'train.txt'), entity2id, relation2id)
    logging.info('#train: %d' % len(train_triples))
    valid_triples = read_triple(os.path.join(args.data_path, 'valid.txt'), entity2id, relation2id)
    logging.info('#valid: %d' % len(valid_triples))
    test_triples = read_triple(os.path.join(args.data_path, 'test.txt'), entity2id, relation2id)
    logging.info('#test: %d' % len(test_triples))

    # All true triples
    all_true_triples = train_triples + valid_triples + test_triples

    kge_model = KGEModel(
        model_name=args.model,
        nentity=nentity,
        nrelation=nrelation,
        hidden_dim=args.hidden_dim,
        gamma=args.gamma,
        double_entity_embedding=args.double_entity_embedding,
        double_relation_embedding=args.double_relation_embedding
    )

    logging.info('Model Parameter Configuration:')
    for name, param in kge_model.named_parameters():
        logging.info('Parameter %s: %s, require_grad = %s' % (name, str(param.size()), str(param.requires_grad)))

    if args.cuda:
        kge_model = kge_model.cuda()

    if args.do_train:
        # Set training dataloader iterator
        train_dataloader_head = DataLoader(
            TrainDataset(train_triples, nentity, nrelation, args.negative_sample_size, 'head-batch'),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=max(1, args.cpu_num//2),
            collate_fn=TrainDataset.collate_fn
        )

        train_dataloader_tail = DataLoader(
            TrainDataset(train_triples, nentity, nrelation, args.negative_sample_size, 'tail-batch'),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=max(1, args.cpu_num//2),
            collate_fn=TrainDataset.collate_fn
        )

        train_iterator = BidirectionalOneShotIterator(train_dataloader_head, train_dataloader_tail)

        # Set training configuration
        current_learning_rate = args.learning_rate
        optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, kge_model.parameters()),
            lr=current_learning_rate
        )
        if args.warm_up_steps:
            warm_up_steps = args.warm_up_steps
        else:
            warm_up_steps = args.max_steps // 2

    if args.init_checkpoint:
        # Restore model from checkpoint directory
        logging.info('Loading checkpoint %s...' % args.init_checkpoint)
        # Without --cuda, map the stored tensors to the CPU so a checkpoint
        # written on a GPU can still be loaded.
        checkpoint = torch.load(
            os.path.join(args.init_checkpoint, 'checkpoint'),
            map_location=None if args.cuda else 'cpu'
        )
        init_step = checkpoint['step']
        kge_model.load_state_dict(checkpoint['model_state_dict'])
        if args.do_train:
            current_learning_rate = checkpoint['current_learning_rate']
            warm_up_steps = checkpoint['warm_up_steps']
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    else:
        logging.info('Ramdomly Initializing %s Model...' % args.model)
        init_step = 0

    step = init_step

    logging.info('Start Training...')
    logging.info('init_step = %d' % init_step)
    logging.info('batch_size = %d' % args.batch_size)
    # This line used to repeat negative_adversarial_sampling (logged again below).
    logging.info('negative_sample_size = %d' % args.negative_sample_size)
    logging.info('hidden_dim = %d' % args.hidden_dim)
    logging.info('gamma = %f' % args.gamma)
    logging.info('negative_adversarial_sampling = %s' % str(args.negative_adversarial_sampling))
    if args.negative_adversarial_sampling:
        logging.info('adversarial_temperature = %f' % args.adversarial_temperature)

    # Curvature is refreshed roughly every five passes over the training set,
    # rounded down to a multiple of 100 steps. Small datasets would round to
    # 0 and make the modulo below divide by zero, so 100 is the minimum.
    curv_up_step = max(100, int(len(train_triples)/args.batch_size*5/100)*100)
    logging.info('curv method = %s' % args.curv_method)
    logging.info('curv capluing = %f' % args.alpha)
    logging.info('curv updata setp = %f' % curv_up_step)

    if args.do_train:
        # %f, not %d: the learning rate is a fraction and %d would print 0.
        logging.info('learning_rate = %f' % current_learning_rate)

        # Graph structures for the curvature computation. They are built on
        # the CPU; the triples tensor follows the model onto the GPU only
        # when --cuda is set.
        train_examples = torch.tensor(train_triples)
        edge_index_map = build_edge_index_map(train_examples)
        neighbors = {i: [] for i in range(args.nentity)}
        for h, _, t in train_triples:
            neighbors[h].append(t)
            neighbors[t].append(h)
        if args.cuda:
            train_examples = train_examples.cuda()
        kappa_edge = torch.zeros(len(train_triples), dtype=torch.float32)

        training_logs = []

        # Training Loop
        for step in range(init_step, args.max_steps):

            # Also refresh on the first iteration so that a run resumed at a
            # step that is not a multiple of curv_up_step hands curvature to
            # the model before its first train_step.
            if step == init_step or step % curv_up_step == 0:
                kappa_edge = _refresh_curvature(
                    kge_model, train_examples, neighbors, edge_index_map, kappa_edge, args)

            log = kge_model.train_step(kge_model, optimizer, train_iterator, args)

            training_logs.append(log)

            if step >= warm_up_steps:
                current_learning_rate = current_learning_rate / 10
                logging.info('Change learning_rate to %f at step %d' % (current_learning_rate, step))
                optimizer = torch.optim.Adam(
                    filter(lambda p: p.requires_grad, kge_model.parameters()),
                    lr=current_learning_rate
                )
                warm_up_steps = warm_up_steps * 3

            if step % args.save_checkpoint_steps == 0:
                save_variable_list = {
                    'step': step,
                    'current_learning_rate': current_learning_rate,
                    'warm_up_steps': warm_up_steps
                }
                save_model(kge_model, optimizer, save_variable_list, args)

            if step % args.log_steps == 0:
                metrics = {}
                for metric in training_logs[0].keys():
                    metrics[metric] = sum([log[metric] for log in training_logs])/len(training_logs)
                log_metrics('Training average', step, metrics)
                training_logs = []

            if args.do_valid and step % args.valid_steps == 0:
                logging.info('Evaluating on Valid Dataset...')
                metrics = kge_model.test_step(kge_model, valid_triples, all_true_triples, args)
                log_metrics('Valid', step, metrics)

        save_variable_list = {
            'step': step,
            'current_learning_rate': current_learning_rate,
            'warm_up_steps': warm_up_steps
        }
        save_model(kge_model, optimizer, save_variable_list, args)

    if args.do_valid:
        logging.info('Evaluating on Valid Dataset...')
        metrics = kge_model.test_step(kge_model, valid_triples, all_true_triples, args)
        log_metrics('Valid', step, metrics)

    if args.do_test:
        logging.info('Evaluating on Test Dataset...')
        metrics = kge_model.test_step(kge_model, test_triples, all_true_triples, args)
        log_metrics('Test', step, metrics)

    if args.evaluate_train:
        logging.info('Evaluating on Training Dataset...')
        metrics = kge_model.test_step(kge_model, train_triples, all_true_triples, args)
        log_metrics('Test', step, metrics)
        
# Script entry point: parse the command line, then run the selected phases.
if __name__ == '__main__':
    cli_args = parse_args()
    main(cli_args)
